/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 * contributors may be used to endorse or promote products derived from this
 * software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <stdint.h>

#include <algorithm>
#include <cstdio>
#include <stdexcept>
#include <vector>

#include "cutlass_matmul.h"
#include "paddle/extension.h"
#include "utils.h"

__host__ __device__ Activation convert_activation(const int activation) {
  switch (activation) {
    case 0:
      return Activation::ReLU;
    case 1:
      return Activation::Exponential;
    case 2:
      return Activation::Sine;
    case 3:
      return Activation::Sigmoid;
    case 4:
      return Activation::Squareplus;
    case 5:
      return Activation::Softplus;
    case 6:
      return Activation::None;
    default:
      return Activation::None;
  }
}

template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
  return (val + divisor - 1) / divisor;
}

void check_shmem_error(cudaError_t error) {
  if (error != cudaSuccess) {
    throw std::runtime_error{
        "FullyFusedMLP: insufficient shared memory available on the GPU. "
        "Reduce `n_neurons` or use `CutlassMLP` (better compatibility but "
        "slower) instead."};
  }
}

template <int WIDTH, int N_ITERS, typename OUT_T, bool BACKWARD = false>
__device__ void threadblock_layer(
    Activation activation, __half* __restrict__ act_shmem,
    const __half* __restrict__ weights_this_layer,
    OUT_T* __restrict__ out_intermediate_threadblock_this_layer,
    const OUT_T* __restrict__ activation_aux = nullptr) {
  // act_shmem: shared-memory buffer holding the intermediate activations of
  //   the thread block's chunk of the batch. Depending on the caller, these
  //   are forward or backward activations.
  // weights_this_layer: pointer to the weight matrix of the current layer.
  // out_intermediate_threadblock_this_layer: location where the intermediate
  //   activations produced by the thread block are written. Can be nullptr if
  //   nothing should be written.
  // activation_aux: additional data the activation function may depend on;
  //   points to the hidden forward activations when computing backward
  //   activations.

  constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
  constexpr uint32_t N_BLOCKS = WIDTH / 16;
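
  // Why the SKEW padding (illustrative reasoning; the numbers assume
  // WIDTH = 64): shared memory is organized in 32 four-byte banks, i.e. a
  // 128-byte period. A row of 64 halves is exactly 128 bytes, so the same
  // column in consecutive rows would hit the same bank and serialize accesses.
  // Padding each row by SKEW = 8 halves (16 bytes) shifts consecutive rows by
  // 4 banks and avoids these conflicts.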

  using namespace nvcuda;

  // If we're performing the backward pass, weights must be loaded in transposed
  // form, which is achieved by interpreting the memory in row_major instead of
  // col_major order.
  using weights_layout_t =
      std::conditional_t<BACKWARD, wmma::row_major, wmma::col_major>;

  // Fragments
  wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, weights_layout_t>
      weights_frag[N_BLOCKS];
  wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS];

  // Indices
  const uint32_t li = threadIdx.x;  // index in warp ("lane index")
  const uint32_t wi = threadIdx.y;  // index in block ("warp index")

  const uint32_t lane_offset = (8 * li) % WIDTH;
  const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;
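
  // Index example (illustrative): with WIDTH = 64, li = 5, wi = 1, this thread
  // handles the 8 contiguous halves starting at column lane_offset = 40 of
  // row = (8 * 5 + 1 * 8 * 32) / 64 = 4 within the block's current 16-row
  // tile; each such 8-half (16-byte) copy below is a single int4 load/store.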

  const uint32_t weights_col = 16 * wi;

  __syncthreads();

// Load N_BLOCKS chunks of weights from global memory into registers.
#pragma unroll
  for (uint32_t i = 0; i < N_BLOCKS; ++i) {
    if (BACKWARD) {
      // If we're performing the backward pass, additional index swizzling is
      // needed to load the weights in transposed form.
      wmma::load_matrix_sync(weights_frag[i],
                             weights_this_layer + 16 * i * WIDTH + weights_col,
                             WIDTH);
    } else {
      wmma::load_matrix_sync(weights_frag[i],
                             weights_this_layer + 16 * i + weights_col * WIDTH,
                             WIDTH);
    }
  }

#pragma unroll
  for (int l = 0; l < N_ITERS; ++l) {
    wmma::fill_fragment(result_frag[l], 0.0f);

#pragma unroll
    for (uint32_t i = 0; i < N_BLOCKS; ++i) {
      // Load a chunk of intermediate activations from shared memory and
      // multiply with chunk of weights
      wmma::load_matrix_sync(act_frag,
                             act_shmem + 16 * i + (16 * l) * (WIDTH + SKEW),
                             WIDTH + SKEW);
      wmma::mma_sync(result_frag[l], act_frag, weights_frag[i], result_frag[l]);
    }

    // Activation
    if (BACKWARD) {
      // Load the temporary forward matrix for the relu transfer
      wmma::load_matrix_sync(
          act_frag, activation_aux + weights_col + l * 16 * WIDTH, WIDTH);
      warp_activation_backward<__half>(activation, result_frag[l], act_frag,
                                       result_frag[l]);
    } else {
      warp_activation<__half>(activation, result_frag[l], result_frag[l]);
    }
  }

  __syncthreads();

#pragma unroll
  for (int l = 0; l < N_ITERS; ++l) {
    wmma::store_matrix_sync(act_shmem + weights_col + l * 16 * (WIDTH + SKEW),
                            result_frag[l], WIDTH + SKEW, wmma::mem_row_major);
  }

  if (out_intermediate_threadblock_this_layer != nullptr) {
    __syncthreads();

#pragma unroll
    for (int l = 0; l < N_ITERS; ++l) {
      *(int4*)&out_intermediate_threadblock_this_layer[lane_offset +
                                                       (row + 16 * l) * WIDTH] =
          *(int4*)&act_shmem[lane_offset + (row + 16 * l) * (WIDTH + SKEW)];
    }
  }
}

template <int WIDTH, int N_ITERS>
__device__ void threadblock_load_input_static(
    __half* __restrict__ act_shmem,
    const __half* __restrict__ input_threadblock) {
  // act_shmem will be filled by the thread block's chunk of input_threadblock

  constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;

  // Indices
  const uint32_t li = threadIdx.x;  // index in warp ("lane index")
  const uint32_t wi = threadIdx.y;  // index in block ("warp index")

  const uint32_t lane_offset = (8 * li) % WIDTH;
  const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

#pragma unroll
  for (int i = 0; i < N_ITERS; ++i) {
    *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)] =
        *(int4*)&input_threadblock[lane_offset + (row + 16 * i) * WIDTH];
  }
}

template <int WIDTH, int N_ITERS, typename OUT_T>
__device__ void threadblock_input_layer_forward_dynamic(
    Activation activation, __half* __restrict__ act_shmem,
    const __half* __restrict__ input_threadblock,
    const __half* __restrict__ weights_this_layer,
    OUT_T* __restrict__ out_intermediate_threadblock_this_layer,
    const uint32_t in_width) {
  // act_shmem: shared-memory buffer holding the intermediate activations of
  //   the thread block's chunk of the batch.
  // input_threadblock: pointer to the thread block's chunk of the input batch
  //   in global memory.
  // weights_this_layer: pointer to the weight matrix of the current layer.
  // out_intermediate_threadblock_this_layer: location where the intermediate
  //   activations produced by the thread block are written. Can be nullptr if
  //   nothing should be written.
  // in_width: the dynamic width of the input layer.

  constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
  constexpr uint32_t INPUT_SKEW = 8;
  constexpr uint32_t N_BLOCKS = WIDTH / 16;

  using namespace nvcuda;

  // Fragments
  wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major>
      weights_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS];

  // Indices
  const uint32_t li = threadIdx.x;  // index in warp ("lane index")
  const uint32_t wi = threadIdx.y;  // index in block ("warp index")

  const uint32_t lane_offset = (8 * li) % WIDTH;
  const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

  const uint32_t weights_col = 16 * wi;

  __half* __restrict__ weights_shmem = act_shmem + 16 * (in_width + INPUT_SKEW);

  // Load input weight matrix (fits completely into shared memory)
  // Each thread can load 8 fp16 elements (16 bytes) at once; we have N_BLOCKS
  // warps
  const uint32_t n_elems_per_load = N_BLOCKS * 32 * 8;
  const uint32_t thread_elem_idx = (li + wi * 32) * 8;

  const uint32_t n_elems_b = WIDTH * in_width;

#pragma unroll
  for (uint32_t idx = thread_elem_idx; idx < n_elems_b;
       idx += n_elems_per_load) {
    const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW;
    *(int4*)&weights_shmem[idx_skewed] = *(int4*)&weights_this_layer[idx];
  }
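
  // Skewed indexing example (illustrative, assuming in_width = 32): a linear
  // weight element index idx = 70 lies in row idx / in_width = 2, so it is
  // stored at idx_skewed = 70 + 2 * INPUT_SKEW = 86 in shared memory. Every
  // row thus gets an 8-half (16-byte) pad, which avoids bank conflicts.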

  const uint32_t n_tensor_ops = in_width / 16;

#pragma unroll
  for (int l = 0; l < N_ITERS; ++l) {
    // Load chunk of inputs into shmem.
    // This is faster than loading it from gmem directly, even though it is
    // only used once. (Possibly due to latency hiding through staging.)
    const uint32_t n_elems_a = 16 * in_width;

#pragma unroll
    for (uint32_t idx = thread_elem_idx; idx < n_elems_a;
         idx += n_elems_per_load) {
      const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW;
      *(int4*)&act_shmem[idx_skewed] =
          *(int4*)&input_threadblock[l * n_elems_a + idx];
    }

    __syncthreads();

    wmma::fill_fragment(result_frag[l], 0.0f);
#pragma unroll
    for (uint32_t i = 0; i < n_tensor_ops; ++i) {
      // Load chunk of inputs and weights from shared memory and multiply them
      wmma::load_matrix_sync(act_frag, act_shmem + 16 * i,
                             in_width + INPUT_SKEW);
      wmma::load_matrix_sync(
          weights_frag,
          weights_shmem + 16 * i + weights_col * (in_width + INPUT_SKEW),
          in_width + INPUT_SKEW);
      wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]);
    }

    __syncthreads();

    warp_activation<__half>(activation, result_frag[l], result_frag[l]);
  }

#pragma unroll
  for (int l = 0; l < N_ITERS; ++l) {
    wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW),
                            result_frag[l], WIDTH + SKEW, wmma::mem_row_major);
  }

  if (out_intermediate_threadblock_this_layer != nullptr) {
    __syncthreads();

#pragma unroll
    for (int i = 0; i < N_ITERS; ++i) {
      *(int4*)&out_intermediate_threadblock_this_layer[lane_offset +
                                                       (row + 16 * i) * WIDTH] =
          *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)];
    }
  }
}

template <int WIDTH, int N_ITERS, typename OUT_T>
__device__ void threadblock_last_layer_forward(
    Activation activation, __half* __restrict__ act_shmem,
    const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out,
    const uint32_t batch_size, const nvcuda::wmma::layout_t output_layout) {
  // act_shmem: shared-memory buffer holding the intermediate activations of
  //   the thread block's chunk of the batch.
  // weights_this_layer: pointer to the weight matrix of the current layer.
  // out: location where the result produced by the thread block is written.
  //   Can be nullptr if nothing should be written.

  constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
  constexpr uint32_t N_BLOCKS = WIDTH / 16;

  using namespace nvcuda;

  // Fragments
  wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major>
      weights_frag[N_BLOCKS];
  wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag;

  // Indices
  const uint32_t li = threadIdx.x;  // index in warp ("lane index")
  const uint32_t wi = threadIdx.y;  // index in block ("warp index")

  __half* __restrict__ weights_shmem =
      act_shmem + N_ITERS * 16 * (WIDTH + SKEW);

  const uint32_t weights_row = (8 * li) % WIDTH;
  const uint32_t weights_col = (8 * li + 8 * 32 * wi) / WIDTH;

  // Load weight matrix into shared memory for the last multiplication.
  // Loading into shared memory as opposed to directly into registers is faster
  // because unlike in the previous layers, each warp uses the same entries of
  // the weight matrix.
  *(int4*)&weights_shmem[weights_row + weights_col * (WIDTH + SKEW)] =
      *(int4*)&weights_this_layer[weights_row + weights_col * WIDTH];
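
  // Coverage check (illustrative): the block has N_BLOCKS * 32 threads and
  // each copies 8 halves, i.e. (WIDTH / 16) * 32 * 8 = 16 * WIDTH elements in
  // total -- exactly the 16 x WIDTH weight block of the last layer, so a
  // single vectorized store per thread suffices.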

  __syncthreads();

#pragma unroll
  for (uint32_t i = 0; i < N_BLOCKS; ++i)
    wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i,
                           WIDTH + SKEW);

  // Perform last layer by parallelizing over iters
  for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) {
    wmma::fill_fragment(result_frag, 0.0f);

#pragma unroll
    for (uint32_t i = 0; i < N_BLOCKS; ++i) {
      // Load a chunk of intermediate activations from shared memory and
      // multiply with chunk of the weight matrix
      wmma::load_matrix_sync(act_frag,
                             act_shmem + 16 * i + (16 * idx) * (WIDTH + SKEW),
                             WIDTH + SKEW);
      wmma::mma_sync(result_frag, act_frag, weights_frag[i], result_frag);
    }

    warp_activation<__half>(activation, result_frag, result_frag);

    if (output_layout == wmma::mem_row_major) {
      wmma::store_matrix_sync(out + idx * 16 * 16, result_frag, 16,
                              output_layout);
    } else {
      wmma::store_matrix_sync(out + idx * 16, result_frag, batch_size,
                              output_layout);
    }
  }
}

template <int WIDTH, int N_ITERS>
__device__ void threadblock_write_output_static(
    const __half* __restrict__ act_shmem,
    __half* __restrict__ output_threadblock) {
  // output_threadblock will be filled by the thread block's act_shmem

  constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;

  // Indices
  const uint32_t li = threadIdx.x;  // index in warp ("lane index")
  const uint32_t wi = threadIdx.y;  // index in block ("warp index")

  const uint32_t lane_offset = (8 * li) % WIDTH;
  const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

  __syncthreads();

#pragma unroll
  for (int i = 0; i < N_ITERS; ++i) {
    *(int4*)&output_threadblock[lane_offset + (row + 16 * i) * WIDTH] =
        *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)];
  }
}

template <int WIDTH, int N_ITERS, typename OUT_T, bool INFERENCE>
__global__ void kernel_mlp_fused(
    const Activation activation, const Activation output_activation,
    const __half* __restrict__ input, const __half* __restrict__ weights,
    OUT_T* __restrict__ out_intermediate, OUT_T* __restrict__ out,
    const uint32_t batch_size, const uint32_t in_width,
    const uint32_t out_width, const uint32_t n_hidden_matmuls,
    const nvcuda::wmma::layout_t output_layout = nvcuda::wmma::mem_row_major) {
  // `input` points to the input matrix. Can be any width.
  // `weights` points to the weight matrices (contiguous in memory).
  // `out_intermediate` points to the memory where intermediate activations
  //   should be written. When performing inference, a value of nullptr is
  //   expected (intermediate results are not written).
  // `out` points to the memory where the network output should be written.
  //   (The output width handled by this kernel is assumed to be 16 neurons.)

  // if (threadIdx.x == 0) printf("[forward] call kernel_mlp_fused\n");
  // if (threadIdx.x == 0) printf("[forward] inputs=%f\n", (float)input[0]);
  // if (threadIdx.x == 0) printf("[forward] weights=%f\n", (float)weights[0]);

  // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n",
  // (float)out_intermediate[0]);

  // Shared memory contains the intermediate activations of 16 * N_ITERS batch
  // elements. In some cases, it also contains the weight matrix for the first
  // and last layer.
  extern __shared__ __half shmem[];
  __half* act_shmem = shmem;

  // Each block processes a contiguous chunk of 16 * N_ITERS batch elements.
  const uint32_t elem_idx = 16 * blockIdx.x * N_ITERS;
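
  // Illustrative example: with N_ITERS = 8 and blockIdx.x = 3, this block
  // covers batch elements 384..511.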

  // First layer
  if (in_width == WIDTH) {
    // If the input has the same width as the network, we can simply use the
    // network's regular layer routine (with static size) instead of using the
    // slower dynamic input layer routine.
    threadblock_load_input_static<WIDTH, N_ITERS>(act_shmem,
                                                  input + elem_idx * WIDTH);
    threadblock_layer<WIDTH, N_ITERS, OUT_T>(
        activation, act_shmem, weights,
        !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr);
  } else {
    threadblock_input_layer_forward_dynamic<WIDTH, N_ITERS, OUT_T>(
        activation, act_shmem, input + elem_idx * in_width, weights,
        !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, in_width);
  }

  // if (threadIdx.x == 0) printf("[forward] kernel_mlp_fused: passed first
  // layer\n");
  // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n",
  // (float)out_intermediate[0]);

  const uint32_t first_layer_size = WIDTH * in_width;
  const uint32_t layer_stride = WIDTH * WIDTH;
  const uint32_t output_stride = WIDTH * batch_size;
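
  // Offset example (illustrative, assuming in_width = 32 and WIDTH = 64): the
  // input layer's weights occupy the first first_layer_size = 64 * 32 = 2048
  // halves, hidden matmul k reads its weights at offset 2048 + 4096 * k, and
  // each successive layer writes its activations output_stride =
  // 64 * batch_size halves further into out_intermediate.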

  // Hidden layers
  for (uint32_t k = 0; k < n_hidden_matmuls; ++k) {
    threadblock_layer<WIDTH, N_ITERS, OUT_T>(
        activation, act_shmem, weights + first_layer_size + layer_stride * k,
        !INFERENCE
            ? (out_intermediate + output_stride * (k + 1) + elem_idx * WIDTH)
            : nullptr);
    // if (threadIdx.x == 0) printf("[forward] kernel_mlp_fused: passed %d
    // layer\n", k + 1);
    // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n",
    // (float)out_intermediate[0]);
  }

  if (out_width > 16) {
    // In the forward pass, intermediate activations are already written out.
    if (INFERENCE) {
      threadblock_write_output_static<WIDTH, N_ITERS>(
          act_shmem, out_intermediate + elem_idx * WIDTH);
    }
  } else if (out) {
    // Last layer
    if (output_layout == nvcuda::wmma::mem_row_major) {
      // printf("[last layer] RM write to out %d\n", elem_idx * 16);
      // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n",
      // (float)out_intermediate[0]);
      threadblock_last_layer_forward<WIDTH, N_ITERS, OUT_T>(
          output_activation, act_shmem,
          weights + first_layer_size + layer_stride * n_hidden_matmuls,
          out + elem_idx * 16, 16, output_layout);
      // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n",
      // (float)out_intermediate[0]);
    } else {
      // printf("[last layer] CM write to out %d\n", elem_idx);
      // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n",
      // (float)out_intermediate[0]);
      threadblock_last_layer_forward<WIDTH, N_ITERS, OUT_T>(
          output_activation, act_shmem,
          weights + first_layer_size + layer_stride * n_hidden_matmuls,
          out + elem_idx, batch_size, output_layout);
    }
  }
}

template <int WIDTH, int N_ITERS, typename OUTPUT_LAYOUT>
__global__ void kernel_mlp_fused_backward(
    const Activation activation, const __half* __restrict__ dL_doutput,
    const __half* __restrict__ weights, __half* __restrict__ out_intermediate,
    const __half* __restrict__ forward, __half* __restrict__ dL_dinput,
    const __half* __restrict__ weights_first_layer, const uint32_t batch_size,
    const uint32_t out_width, const uint32_t n_hidden_matmuls) {
  // `dL_doutput` points to the input matrix of the backward pass, i.e. the
  //   loss gradients. Assumed to be 16 neurons wide.
  // `weights` points to the weight matrices (contiguous in memory).
  // `out_intermediate` points to the memory where backpropagated activation
  //   gradients should be written.
  // `forward` points to the memory where the intermediate activations of the
  //   forward pass are located (needed for activation backprop).

  constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;

  // Indices
  const uint32_t li = threadIdx.x;  // index in warp ("lane index")
  const uint32_t wi = threadIdx.y;  // index in block ("warp index")
  const uint32_t bi = blockIdx.x;   // block index

  // Shared memory contains the intermediate activation gradients of
  // 16 * N_ITERS batch elements. A skew is applied to the matrix storage to
  // avoid bank conflicts.
  extern __shared__ __half shmem[];
  __half* act_shmem = shmem;

  const uint32_t lane_offset = (8 * li) % WIDTH;
  const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

  // Multiplying one 16-row chunk of intermediate activations by the weight
  // matrix requires all warps of the block. Each block therefore computes
  // N_ITERS consecutive 16-row chunks of the next layer's intermediate
  // activations.
  const uint32_t elem_idx_base = 16 * bi * N_ITERS;
  const uint32_t elem_idx = elem_idx_base;

  const uint32_t layer_stride = WIDTH * WIDTH;
  const uint32_t output_stride = WIDTH * batch_size;

  // Backprop through last layer
  if (out_width <= 16) {
    using namespace nvcuda;

    // Fragments in registers
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, OUTPUT_LAYOUT> act_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major>
        weights_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> result_frag[N_ITERS];

    // Load the relevant chunk of the last layer's weight matrix from global
    // memory into registers
    const uint32_t weights_col = 16 * wi;

    wmma::load_matrix_sync(
        weights_frag, weights + layer_stride * n_hidden_matmuls + weights_col,
        WIDTH);

#pragma unroll
    for (int l = 0; l < N_ITERS; ++l) {
      wmma::fill_fragment(result_frag[l], 0.0f);

      // Load a chunk of output gradients from shared memory and multiply with
      // previously loaded weights
      if (std::is_same<OUTPUT_LAYOUT, wmma::row_major>::value) {
        wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * l) * 16,
                               16);
      } else {
        wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * l),
                               batch_size);
      }

      // NOTE: the activation transfer of the _output_ activation is expected
      // to have been done in a separate pass _prior_ to calling this kernel,
      // because the transferred activation gradient is also needed to compute
      // the weight gradient of the last weight matrix (see backward()).
      wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]);

      // Load the temporary forward matrix for the relu transfer
      wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major>
          forward_frag;
      wmma::load_matrix_sync(forward_frag,
                             forward + output_stride * n_hidden_matmuls +
                                 weights_col + (elem_idx + l * 16) * WIDTH,
                             WIDTH);

      warp_activation_backward<__half>(activation, result_frag[l], forward_frag,
                                       result_frag[l]);
    }

    __syncthreads();

#pragma unroll
    for (int l = 0; l < N_ITERS; ++l) {
      wmma::store_matrix_sync(
          act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), result_frag[l],
          WIDTH + SKEW, wmma::mem_row_major);
    }

    __syncthreads();

#pragma unroll
    for (int i = 0; i < N_ITERS; ++i) {
      *(int4*)&out_intermediate[lane_offset +
                                (row + elem_idx + i * 16) * WIDTH] =
          *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)];
    }
  } else {
    // If the output width is larger than 16, we will have used CUTLASS for
    // backpropping through the last layer. Load the resulting gradients.
    threadblock_load_input_static<WIDTH, N_ITERS>(
        act_shmem, out_intermediate + elem_idx * WIDTH);
  }

  // Backprop through hidden layers
  for (uint32_t k = 0; k < n_hidden_matmuls; ++k) {
    threadblock_layer<WIDTH, N_ITERS, __half, true>(
        activation, act_shmem,
        weights + layer_stride * (n_hidden_matmuls - k - 1),
        out_intermediate + output_stride * (k + 1) + elem_idx_base * WIDTH,
        forward + output_stride * (n_hidden_matmuls - k - 1) +
            elem_idx_base * WIDTH);
  }

  // Compute loss gradients w.r.t. input if desired.
  // THIS CODE ASSUMES THAT THE INPUT WIDTH IS THE SAME AS THE NETWORK WIDTH.
  // DON'T PASS A NON-NULL dL_dinput IF THIS REQUIREMENT IS NOT MET.
  if (dL_dinput != nullptr) {
    threadblock_layer<WIDTH, N_ITERS, __half, true>(
        Activation::None, act_shmem, weights_first_layer,
        dL_dinput + elem_idx_base * WIDTH);
  }
}

//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////

template <uint32_t WIDTH, bool INFERENCE>  // WIDTH is hidden_dim
void ffmlp_forward_cuda(const __half* inputs, const __half* weights,
                        const uint32_t B, const uint32_t input_dim,
                        const uint32_t output_dim, const uint32_t num_layers,
                        const Activation activation,
                        const Activation output_activation,
                        __half* forward_buffer, __half* outputs) {
  constexpr uint32_t SKEW =
      WIDTH % 16 == 0 ? 8 : 0;  // <- always going to be 8 as we only support
                                // multiple-of-16 widths
  constexpr uint32_t INPUT_SKEW = 8;  // <- likewise with inputs
  constexpr uint32_t N_BLOCK_ROWS = WIDTH / 16;

  const int N_ITERS = WIDTH >= 256 ? 2 : 8;

  const dim3 threads = {32u, N_BLOCK_ROWS,
                        1};  // 32 threads = 1 warp, N_BLOCK_ROWS warps per
                             // block for 16 rows, up to 2x 8 warps can share
                             // input (does not help vs. 1)

  uint32_t n_elems_per_block = 16 * N_ITERS;
  uint32_t n_blocks = div_round_up(B, n_elems_per_block);
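
  // Launch-size example (illustrative): for WIDTH = 64 the kernel uses
  // N_ITERS = 8, so each block handles n_elems_per_block = 128 batch elements;
  // a batch of B = 65536 therefore launches n_blocks = 512 blocks of
  // 32 x N_BLOCK_ROWS = 32 x 4 threads.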

  size_t shmem_size =
      sizeof(__half) * (16 + 16 * N_ITERS) *
      (WIDTH + SKEW);  // 16*WIDTH rows of weights (for the last layer; others
                       // are in registers only) + 16*WIDTH*BLOCK_DIM_Z*N_ITERS
                       // rows of intermediate activations
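
  // Shared-memory example (illustrative, WIDTH = 64, N_ITERS = 8, SKEW = 8):
  // shmem_size = 2 * (16 + 128) * 72 = 20736 bytes, i.e. about 20 KiB of
  // dynamic shared memory per block.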

  // If the input width is dynamic, the input weight matrix as well as part of
  // the input will live in extra shared memory
  if (input_dim != WIDTH) {
    shmem_size = std::max(
        shmem_size, sizeof(__half) * (WIDTH + 16) * (input_dim + INPUT_SKEW));
  }

  // printf("[ffmlp_forward_cuda] shmem size = %d\n", shmem_size);

  const dim3 blocks = {n_blocks, 1u, 1u};

  check_shmem_error(cudaFuncSetAttribute(
      kernel_mlp_fused<WIDTH, N_ITERS, __half, INFERENCE>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size));

  kernel_mlp_fused<WIDTH, N_ITERS, __half, INFERENCE>
      <<<blocks, threads, shmem_size, 0>>>(
          activation, output_activation,
          inputs,          // CM
          weights,         // RM
          forward_buffer,  // CM
          outputs,         // CM
          B, input_dim, output_dim, num_layers - 1,
          nvcuda::wmma::mem_row_major  // row-major [B, 16] view of the
                                       // col-major [16, B] outputs
      );
}

template <uint32_t WIDTH>  // WIDTH is hidden_dim
void ffmlp_backward_cuda(const __half* grad, const __half* weights,
                         const uint32_t B, const uint32_t input_dim,
                         const uint32_t output_dim, const uint32_t num_layers,
                         const Activation activation,
                         const __half* forward_buffer, __half* backward_buffer,
                         __half* grad_inputs) {
  // Locate the first and second weight matrices within the contiguous
  // weights buffer.
  const __half* weights_first = weights;
  const __half* weights_second = weights + input_dim * WIDTH;

  constexpr uint32_t SKEW =
      WIDTH % 16 == 0 ? 8 : 0;  // <- always going to be 8 as we only support
                                // multiple-of-16 widths
  constexpr uint32_t N_BLOCKS = WIDTH / 16;

  const int N_ITERS = WIDTH >= 256 ? 2 : 8;

  const dim3 threads = {
      32u, N_BLOCKS,
      1};  // 32 threads = 1 warp, N_BLOCKS warps per block for 16 rows, up to
           // 2x 8 warps can share input (does not help vs. 1)

  uint32_t n_elems_per_block = 16 * N_ITERS;
  uint32_t n_blocks = div_round_up(B, n_elems_per_block);

  int shmem_size =
      sizeof(__half) * (16 * N_ITERS) *
      (WIDTH + SKEW);  // 16 * N_ITERS rows of intermediate activation
                       // gradients, each padded to WIDTH + SKEW elements

  const dim3 blocks = {n_blocks, 1u, 1u};

  // The kernels operate with transposed layouts compared with the MLP code
  check_shmem_error(cudaFuncSetAttribute(
      kernel_mlp_fused_backward<WIDTH, N_ITERS, nvcuda::wmma::row_major>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));

  kernel_mlp_fused_backward<WIDTH, N_ITERS, nvcuda::wmma::row_major>
      <<<blocks, threads, shmem_size, 0>>>(activation,
                                           grad,             // CM
                                           weights_second,   // RM
                                           backward_buffer,  // CM
                                           forward_buffer,   // CM
                                           grad_inputs,      // CM
                                           weights_first,    // RM
                                           B, output_dim, num_layers - 1);
}

// inputs: col-major [input_dim, B]
// weights: row-major [hidden_dim * input_dim] + [hidden_dim * hidden_dim *
//   (num_layers - 1)] + [output_dim * hidden_dim]
// forward_buffer: col-major [num_layers, hidden_dim, B]
// outputs: col-major [output_dim, B]
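//
// Layout example (illustrative, assuming input_dim = 32, hidden_dim = 64,
// num_layers = 3, output_dim = 16): `weights` holds 64*32 + 64*64*2 + 16*64 =
// 2048 + 8192 + 1024 = 11264 halves laid out back to back in that order, and
// `forward_buffer` stores one [hidden_dim, B] activation matrix per layer.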
std::vector<paddle::Tensor> ffmlp_forward(const paddle::Tensor& inputs,
                                          const paddle::Tensor& weights,
                                          const int64_t output_dim,
                                          const int64_t hidden_dim,
                                          const int64_t num_layers,
                                          const int activation_,
                                          const int output_activation_) {
  CHECK_CUDA(inputs);
  CHECK_IS_HALF(inputs);
  PD_CHECK(inputs.shape().size() == 2);

  CHECK_CUDA(weights);
  CHECK_IS_HALF(weights);

  Activation activation = convert_activation(activation_);
  Activation output_activation = convert_activation(output_activation_);

  const int64_t B = inputs.shape()[0];
  const int64_t input_dim = inputs.shape()[1];

  auto inputs_ptr =
      reinterpret_cast<const __half*>(inputs.data<paddle::float16>());
  auto weights_ptr =
      reinterpret_cast<const __half*>(weights.data<paddle::float16>());

  auto forward_buffer =
      paddle::empty({num_layers, B, hidden_dim}, paddle::DataType::FLOAT16,
                    paddle::GPUPlace());
  auto outputs = paddle::empty({B, output_dim}, paddle::DataType::FLOAT16,
                               paddle::GPUPlace());

  auto forward_buffer_ptr =
      reinterpret_cast<__half*>(forward_buffer.data<paddle::float16>());
  auto outputs_ptr = reinterpret_cast<__half*>(outputs.data<paddle::float16>());

  switch (hidden_dim) {
    case 16:
      ffmlp_forward_cuda<16, false>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, forward_buffer_ptr, outputs_ptr);
      break;
    case 32:
      ffmlp_forward_cuda<32, false>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, forward_buffer_ptr, outputs_ptr);
      break;
    case 64:
      ffmlp_forward_cuda<64, false>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, forward_buffer_ptr, outputs_ptr);
      break;
    case 128:
      ffmlp_forward_cuda<128, false>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, forward_buffer_ptr, outputs_ptr);
      break;
    case 256:
      ffmlp_forward_cuda<256, false>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, forward_buffer_ptr, outputs_ptr);
      break;
    default:
      throw std::runtime_error{"hidden_dim should in [16, 32, 64, 128, 256]"};
  }

  // for output_dim > 16
  if (output_dim > 16) {
    fc_multiply<LastLayer, true, false, false>(
        0, output_dim, hidden_dim, B,
        (weights_ptr + hidden_dim * input_dim +
         (num_layers - 1) * hidden_dim *
             hidden_dim),  // row-major, [output_dim, hidden_dim]
        (forward_buffer_ptr +
         (num_layers - 1) * hidden_dim * B),  // col-major [hidden_dim, B]
        outputs_ptr,                          // col-major [output_dim, B]
        output_activation);
  }

  return {outputs, forward_buffer};
}

std::vector<paddle::Tensor> ffmlp_inference(const paddle::Tensor& inputs,
                                            const paddle::Tensor& weights,
                                            const int64_t output_dim,
                                            const int64_t hidden_dim,
                                            const int64_t num_layers,
                                            const int activation_,
                                            const int output_activation_) {
  CHECK_CUDA(inputs);
  CHECK_IS_HALF(inputs);
  PD_CHECK(inputs.shape().size() == 2);

  CHECK_CUDA(weights);
  CHECK_IS_HALF(weights);

  const int64_t B = inputs.shape()[0];
  const int64_t input_dim = inputs.shape()[1];

  Activation activation = convert_activation(activation_);
  Activation output_activation = convert_activation(output_activation_);

  auto inputs_ptr =
      reinterpret_cast<const __half*>(inputs.data<paddle::float16>());
  auto weights_ptr =
      reinterpret_cast<const __half*>(weights.data<paddle::float16>());

  auto inference_buffer = paddle::empty(
      {B, hidden_dim}, paddle::DataType::FLOAT16, paddle::GPUPlace());
  auto outputs = paddle::empty({B, output_dim}, paddle::DataType::FLOAT16,
                               paddle::GPUPlace());

  auto inference_buffer_ptr =
      reinterpret_cast<__half*>(inference_buffer.data<paddle::float16>());
  auto outputs_ptr = reinterpret_cast<__half*>(outputs.data<paddle::float16>());

  switch (hidden_dim) {
    case 16:
      ffmlp_forward_cuda<16, true>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, inference_buffer_ptr, outputs_ptr);
      break;
    case 32:
      ffmlp_forward_cuda<32, true>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, inference_buffer_ptr, outputs_ptr);
      break;
    case 64:
      ffmlp_forward_cuda<64, true>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, inference_buffer_ptr, outputs_ptr);
      break;
    case 128:
      ffmlp_forward_cuda<128, true>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, inference_buffer_ptr, outputs_ptr);
      break;
    case 256:
      ffmlp_forward_cuda<256, true>(
          inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers,
          activation, output_activation, inference_buffer_ptr, outputs_ptr);
      break;
    default:
      throw std::runtime_error{"hidden_dim should in [16, 32, 64, 128, 256]"};
  }

  // for output_dim > 16
  if (output_dim > 16) {
    fc_multiply<LastLayer, true, false, false>(
        0, output_dim, hidden_dim, B,
        (weights_ptr + hidden_dim * input_dim +
         (num_layers - 1) * hidden_dim *
             hidden_dim),      // row-major, [output_dim, hidden_dim]
        inference_buffer_ptr,  // col-major [hidden_dim, B]
        outputs_ptr,           // col-major [output_dim, B]
        output_activation);
  }

  return {outputs};
}

inline std::vector<cudaStream_t>& streams_splitk() {
  static std::vector<cudaStream_t> res;
  return res;
}

inline std::vector<cudaEvent_t>& events_splitk() {
  static std::vector<cudaEvent_t> res;
  return res;
}

void allocate_splitk(size_t size) {
  auto& streams = streams_splitk();
  auto& events = events_splitk();
  streams.resize(size);
  events.resize(size);
  for (size_t i = 0; i < size; i++) {
    CUDA_CHECK_THROW(cudaStreamCreate(&streams[i]));
    CUDA_CHECK_THROW(cudaEventCreate(&events[i]));
  }
}
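
// Note (assumption about the calling convention): ffmlp_backward below indexes
// streams_splitk()/events_splitk() with 0..num_layers, so the host side is
// expected to call allocate_splitk(num_layers + 1) (or more) before the first
// backward pass and free_splitk() at shutdown.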

void free_splitk() {
  auto& streams = streams_splitk();
  auto& events = events_splitk();
  for (size_t i = 0; i < streams.size(); i++) {
    cutlass_free_workspace(streams[i]);
    CUDA_CHECK_PRINT(cudaStreamDestroy(streams[i]));
    CUDA_CHECK_PRINT(cudaEventDestroy(events[i]));
  }
}

// grad: col-major [output_dim, B]
// inputs: col-major [input_dim, B]
// weights: row-major [hidden_dim * input_dim] + [hidden_dim * hidden_dim *
//   (num_layers - 1)] + [output_dim * hidden_dim]
// forward_buffer: col-major [num_layers, hidden_dim, B]
// backward_buffer: col-major [num_layers, hidden_dim, B]
// grad_inputs: col-major [input_dim, B]
// grad_weights: row-major [hidden_dim * input_dim] + [hidden_dim * hidden_dim *
//   (num_layers - 1)] + [output_dim * hidden_dim]
std::vector<paddle::Tensor> ffmlp_backward(
    const paddle::Tensor& inputs, const paddle::Tensor& weights,
    const paddle::Tensor& forward_buffer, const paddle::Tensor& grad,
    const int64_t output_dim, const int64_t hidden_dim,
    const int64_t num_layers, const int activation_,
    const int output_activation_) {
  CHECK_CUDA(inputs);
  CHECK_IS_HALF(inputs);

  CHECK_CUDA(weights);
  CHECK_IS_HALF(weights);

  CHECK_CUDA(forward_buffer);
  CHECK_IS_HALF(forward_buffer);

  CHECK_CUDA(grad);
  CHECK_IS_HALF(grad);

  const int64_t B = inputs.shape()[0];
  const int64_t input_dim = inputs.shape()[1];

  Activation activation = convert_activation(activation_);
  Activation output_activation = convert_activation(output_activation_);

  // activation_backward_output_gpu (output_activation is deliberately not
  // handled here)

  int split_k_factor = B / std::min((int64_t)(1 << 12), B);
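
  // Split-k factor example (illustrative): for B = 65536 the reduction over
  // the batch is split into 65536 / 4096 = 16 partial sums per weight
  // gradient; for B <= 4096 the factor is 1 and the split-k path degenerates
  // to a single GEMM.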

  uint32_t forward_index = num_layers - 1;
  uint32_t backward_index = 0;

  auto backward_buffer =
      paddle::full({num_layers, B, hidden_dim}, 0.0, paddle::DataType::FLOAT16,
                   paddle::GPUPlace());
  auto backward_buffer_ptr =
      reinterpret_cast<__half*>(backward_buffer.data<paddle::float16>());

  auto forward_buffer_ptr =
      reinterpret_cast<const __half*>(forward_buffer.data<paddle::float16>());
  auto grad_ptr = reinterpret_cast<const __half*>(grad.data<paddle::float16>());
  auto inputs_ptr =
      reinterpret_cast<const __half*>(inputs.data<paddle::float16>());
  auto weights_ptr =
      reinterpret_cast<const __half*>(weights.data<paddle::float16>());

  auto grad_weights = paddle::experimental::full_like(weights, 0.0);
  auto grad_weights_ptr =
      reinterpret_cast<__half*>(grad_weights.data<paddle::float16>());

  auto grad_inputs = paddle::experimental::full_like(inputs, 0.0);
  auto grad_inputs_ptr =
      reinterpret_cast<__half*>(grad_inputs.data<paddle::float16>());

  auto grad_inputs_fused_ptr =
      input_dim == hidden_dim ? grad_inputs_ptr : nullptr;
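
  // The fused backward kernel can only produce dL/dinput when the input width
  // equals the network width (see the note in kernel_mlp_fused_backward), so
  // grad_inputs is otherwise computed by the separate fc_multiply call further
  // below.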

  // calc output layer, grad_weights
  cudaEventRecord(events_splitk().at(backward_index), 0);
  cudaStreamWaitEvent(streams_splitk().at(backward_index),
                      events_splitk().at(backward_index), 0);

  fc_multiply_split_k<LastLayerK, false, true, true>(
      streams_splitk().at(backward_index), output_dim, B, hidden_dim,
      grad_ptr,  // col-major, [output_dim, B]
      (forward_buffer_ptr +
       forward_index * hidden_dim * B),  // row-major, [B, hidden_dim]
      (grad_weights_ptr + hidden_dim * input_dim +
       (num_layers - 1) * hidden_dim *
           hidden_dim),  // row-major, [output_dim, hidden_dim]
      split_k_factor);

  cudaEventRecord(events_splitk().at(backward_index),
                  streams_splitk().at(backward_index));

  // prepare the last backward_buffer if output_dim > 16
  if (output_dim > 16) {
    fc_multiply<FullLayer, false, false, false>(
        0, hidden_dim, output_dim, B,
        (grad_weights_ptr + hidden_dim * input_dim +
         (num_layers - 1) * hidden_dim *
             hidden_dim),  // col-major, [hidden_dim, output_dim]
        grad_ptr,          // col-major, [output_dim, B]
        (forward_buffer_ptr +
         forward_index * hidden_dim * B),  // col-major, [hidden_dim, B]
        (backward_buffer_ptr +
         backward_index * hidden_dim * B),  // col-major [hidden_dim, B]
        activation, true);
  }

  // prepare backward_buffer
  // calc grad_inputs if input_dim == hidden_dim
  switch (hidden_dim) {
    case 16:
      ffmlp_backward_cuda<16>(grad_ptr, weights_ptr, B, input_dim, output_dim,
                              num_layers, activation, forward_buffer_ptr,
                              backward_buffer_ptr, grad_inputs_fused_ptr);
      break;
    case 32:
      ffmlp_backward_cuda<32>(grad_ptr, weights_ptr, B, input_dim, output_dim,
                              num_layers, activation, forward_buffer_ptr,
                              backward_buffer_ptr, grad_inputs_fused_ptr);
      break;
    case 64:
      ffmlp_backward_cuda<64>(grad_ptr, weights_ptr, B, input_dim, output_dim,
                              num_layers, activation, forward_buffer_ptr,
                              backward_buffer_ptr, grad_inputs_fused_ptr);
      break;
    case 128:
      ffmlp_backward_cuda<128>(grad_ptr, weights_ptr, B, input_dim, output_dim,
                               num_layers, activation, forward_buffer_ptr,
                               backward_buffer_ptr, grad_inputs_fused_ptr);
      break;
    case 256:
      ffmlp_backward_cuda<256>(grad_ptr, weights_ptr, B, input_dim, output_dim,
                               num_layers, activation, forward_buffer_ptr,
                               backward_buffer_ptr, grad_inputs_fused_ptr);
      break;
    default:
      throw std::runtime_error{"hidden_dim should in [16, 32, 64, 128, 256]"};
  }

  // printf("[backward] finished backward kernel\n");

  forward_index--;
  backward_index++;

  // calc middle layer's grad_weights
  for (uint32_t i = 0; i < num_layers - 1; i++) {
    uint32_t matrix_index = num_layers - 2 - i;

    cudaEventRecord(events_splitk().at(backward_index), 0);
    cudaStreamWaitEvent(streams_splitk().at(backward_index),
                        events_splitk().at(backward_index), 0);

    fc_multiply_split_k<FullLayerK, false, true, true>(
        streams_splitk().at(backward_index), hidden_dim, B, hidden_dim,
        (backward_buffer_ptr +
         (backward_index - 1) * hidden_dim * B),  // col-major [hidden_dim, B]
        (forward_buffer_ptr +
         forward_index * hidden_dim * B),  // row-major [B, hidden_dim]
        (grad_weights_ptr + hidden_dim * input_dim +
         matrix_index * hidden_dim *
             hidden_dim),  // row-major, [hidden_dim, hidden_dim]
        split_k_factor);

    cudaEventRecord(events_splitk().at(backward_index),
                    streams_splitk().at(backward_index));

    forward_index--;
    backward_index++;
  }

  // calc input layer's grad_weights
  cudaEventRecord(events_splitk().at(backward_index), 0);
  cudaStreamWaitEvent(streams_splitk().at(backward_index),
                      events_splitk().at(backward_index), 0);

  fc_multiply_split_k<FullLayerK, false, true, true>(
      streams_splitk().at(backward_index), hidden_dim, B, input_dim,
      (backward_buffer_ptr +
       (backward_index - 1) * hidden_dim * B),  // col-major [hidden_dim, B]
      inputs_ptr,                               // row-major, [B, input_dim]
      grad_weights_ptr,  // row-major, [hidden_dim, input_dim]
      split_k_factor);

  cudaEventRecord(events_splitk().at(backward_index),
                  streams_splitk().at(backward_index));

  // calc grad_inputs if input_dim != hidden_dim
  if (grad_inputs_fused_ptr == nullptr) {
    fc_multiply<FullLayer, false, false, false>(
        0, input_dim, hidden_dim, B,
        weights_ptr,  // col-major [input_dim, hidden_dim]
        (backward_buffer_ptr +
         (backward_index - 1) * hidden_dim * B),  // col-major [hidden_dim, B]
        grad_inputs_ptr                           // col-major [input_dim, B]
    );
  }

  // All the per-layer split-k matrix multiplications summing over
  // the batch are computed in parallel streams to the actual
  // backpropagation. Here, we need to wait for all of these to complete.
  for (auto& event : events_splitk()) {
    cudaStreamWaitEvent(0, event, 0);
  }

  return {grad_inputs, grad_weights};
}
